import json
import base64
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
import google.generativeai as genai
from PIL import Image
import torch
from diffusers import StableDiffusionPipeline
import io


class CheckpointDescriptionExtractor:
    """
    Extract visual descriptions from checkpoints using VLM (Vision Language Model).
    This runs when adding new checkpoints to generate their descriptions.

    The workflow is: render a few sample images with the checkpoint through
    Stable Diffusion, send those images to Gemini with a description prompt,
    and store the resulting one-sentence description in a JSON metadata file.
    """

    # Generic prompts used to sample the checkpoint; they are cycled when more
    # samples are requested than there are prompts.
    _SAMPLE_PROMPTS = (
        "portrait photo",
        "full body standing",
        "close-up face",
        "action pose",
    )

    def __init__(
        self,
        model_name: str = "gemini-2.5-flash-exp",
        api_key: Optional[str] = None,
    ):
        """
        Initialize Gemini for visual description extraction.

        Args:
            model_name: Identifier of the Gemini model to use.
                NOTE(review): confirm the default id is a model that is
                actually available to the configured account.
            api_key: Gemini API key. When omitted, the ``GEMINI_API_KEY`` and
                then ``GOOGLE_API_KEY`` environment variables are consulted.
                If none is set, ``None`` is handed to ``genai.configure`` and
                key resolution is left to the SDK.
        """
        # Local import keeps the file's top-level import block untouched.
        import os

        # Never embed credentials in source: take the key from the caller or
        # from the environment instead of a hard-coded literal.
        resolved_key = (
            api_key
            or os.environ.get("GEMINI_API_KEY")
            or os.environ.get("GOOGLE_API_KEY")
        )
        genai.configure(api_key=resolved_key)
        self.model = genai.GenerativeModel(model_name)

    def generate_sample_images(
        self,
        checkpoint_path: str,
        num_samples: int = 4
    ) -> List[Image.Image]:
        """
        Generate sample images using the checkpoint to extract visual characteristics.

        Args:
            checkpoint_path: Path to a single-file Stable Diffusion checkpoint.
            num_samples: Number of images to generate. The built-in prompts
                are reused cyclically when this exceeds their count.

        Returns:
            A list of exactly ``num_samples`` generated images.

        Raises:
            ValueError: If ``num_samples`` is less than 1.
        """
        # Validate before the (expensive) checkpoint load. A negative value
        # would otherwise slice the prompt list from the end.
        if num_samples < 1:
            raise ValueError(f"num_samples must be at least 1, got {num_samples}")

        # Half precision is only used on CUDA; fall back to float32 on CPU so
        # the method still works (slowly) on machines without a GPU.
        use_cuda = torch.cuda.is_available()
        pipe = StableDiffusionPipeline.from_single_file(
            checkpoint_path,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )

        try:
            pipe = pipe.to("cuda" if use_cuda else "cpu")

            # Generate diverse samples to capture the checkpoint's style.
            prompt_count = len(self._SAMPLE_PROMPTS)
            prompts = [
                self._SAMPLE_PROMPTS[index % prompt_count]
                for index in range(num_samples)
            ]
            return [
                pipe(
                    prompt=prompt,
                    num_inference_steps=20,
                    guidance_scale=7.5,
                ).images[0]
                for prompt in prompts
            ]
        finally:
            # Release the pipeline even when loading onto the device or
            # inference fails, so a failed run does not pin GPU memory.
            del pipe
            if use_cuda:
                torch.cuda.empty_cache()

    def extract_description_from_images(
        self,
        images: List[Image.Image],
        subject_type: str = "subject"
    ) -> str:
        """
        Use VLM to extract visual description from generated images.

        Args:
            images: Images to analyze; must contain at least one image.
            subject_type: Word used in the prompt to name what the images show.

        Returns:
            The model's response text with surrounding whitespace stripped.

        Raises:
            ValueError: If ``images`` is empty. Reading ``response.text`` may
                also raise when the model returns no usable text —
                NOTE(review): confirm against the SDK version in use.
        """
        # Without images the prompt below would ask the model to analyze
        # nothing, so fail early instead of making a pointless request.
        if not images:
            raise ValueError("at least one image is required to extract a description")

        # Convert images to base64-encoded PNG parts for Gemini
        image_parts = []
        for img in images:
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode()
            image_parts.append({
                "inline_data": {
                    "mime_type": "image/png",
                    "data": img_base64
                }
            })

        prompt = f"""Analyze these images showing the same {subject_type} from different angles.
        Provide a concise visual description focusing on:
        1. Art style (photorealistic, cartoon, anime, painted, etc.)
        2. Key visual characteristics (colors, textures, distinctive features)
        3. Overall appearance and mood
        
        Write a single descriptive sentence that captures the essence of this {subject_type}'s visual style.
        Focus on what makes this version distinctive from other {subject_type} representations.
        """

        response = self.model.generate_content([prompt] + image_parts)
        return response.text.strip()

    def process_new_checkpoint(
        self,
        checkpoint_path: str,
        subject_type: str,
        output_metadata_path: str,
        num_samples: int = 4
    ) -> str:
        """
        Complete pipeline to generate description for a new checkpoint.

        Args:
            checkpoint_path: Path to the checkpoint to describe.
            subject_type: Word naming the subject; used in the prompt and
                recorded in the metadata.
            output_metadata_path: Path of the JSON metadata file to write
                (overwritten if it exists).
            num_samples: Number of sample images to generate and analyze.

        Returns:
            The description that was written to the metadata file.

        Raises:
            ValueError: If ``num_samples`` is less than 1.
        """
        # Generate sample images
        images = self.generate_sample_images(checkpoint_path, num_samples)

        # Extract description using VLM
        description = self.extract_description_from_images(images, subject_type)

        # Save to metadata
        metadata = {
            "description": description,
            "subject_types": [subject_type],
            "visual_extraction_method": "gemini_vlm",
            "num_samples_used": len(images)
        }

        # Explicit encoding so the file does not depend on the platform default.
        with open(output_metadata_path, 'w', encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        return description


